from typing import Any
from torch.utils.data import Dataset

class TrainSubset(Dataset):
    """Dataset view that exposes only selected positions of another dataset.

    Position ``i`` of this object maps to ``dataset[indices[i]]``. The
    wrapped dataset's items are unpacked into exactly two values, named
    ``img`` and ``label`` here, and handed back as a 2-tuple.
    """

    def __init__(self, dataset: Dataset, indices=None):
        """Store the wrapped dataset and the positions to expose.

        Args:
            dataset: The underlying dataset; must support ``len()`` when
                ``indices`` is omitted.
            indices: Positions into ``dataset`` to expose, in the order they
                should appear. Kept by reference (not copied). When ``None``,
                every position ``0 .. len(dataset) - 1`` is used.
        """
        self.dataset = dataset
        # Only ``None`` triggers the default; an explicitly empty selection
        # is kept as-is and yields an empty subset.
        self.indices = (
            indices if indices is not None else list(range(len(dataset)))
        )

    def __len__(self) -> int:
        """Return the number of exposed positions."""
        return len(self.indices)

    def __getitem__(self, index: int) -> Any:
        """Return the ``(img, label)`` pair at subset position ``index``.

        Raises whatever ``self.indices`` or the wrapped dataset raise for an
        invalid position, and ``ValueError``/``TypeError`` if the wrapped
        item does not unpack into exactly two values.
        """
        sample = self.dataset[self.indices[index]]
        # Unpack and re-pack so the result is always a 2-tuple, regardless of
        # the container type the wrapped dataset returned.
        img, label = sample
        return img, label